"""
Phase 7: OECD Suppression Story & Eurozone Split
===================================================
Section 1: Decompose OECD Z→CA into capital-deepening (surplus) and
           residual (deficit) channels. Show full-sample amplification
           is driven by non-OECD composition.
Section 2: Eurozone vs OECD floaters — does K/Y mediation differ
           under fixed vs floating exchange rates?
Section 3: Cross-tabulation summary table.
"""

import pandas as pd
import numpy as np
from pathlib import Path
import sys
import warnings
warnings.filterwarnings('ignore')

PROJECT_DIR = Path(__file__).resolve().parent.parent
ROOT_DIR = PROJECT_DIR.parent
sys.path.insert(0, str(ROOT_DIR / "multilateral" / "src"))
from model import PanelGLS

# Directory holding the input panel CSV and the directory receiving the
# generated markdown tables.
DATA = PROJECT_DIR / "data" / "processed"
OUT_TABLES = PROJECT_DIR / "output" / "tables"
# Runs at import time, so the output directory exists before any table is written.
OUT_TABLES.mkdir(parents=True, exist_ok=True)

# ISO3 codes treated as OECD members when the panel is split into
# OECD / non-OECD subsamples.
OECD = [
    'AUS', 'AUT', 'BEL', 'CAN', 'CHL', 'COL', 'CRI', 'CZE', 'DNK', 'EST',
    'FIN', 'FRA', 'DEU', 'GRC', 'HUN', 'ISL', 'IRL', 'ISR', 'ITA', 'JPN',
    'KOR', 'LVA', 'LTU', 'LUX', 'MEX', 'NLD', 'NZL', 'NOR', 'POL', 'PRT',
    'SVK', 'SVN', 'ESP', 'SWE', 'CHE', 'TUR', 'GBR', 'USA',
]

# ISO3 code -> first year counted as eurozone membership; rows with
# year >= this value are treated as post-join observations.
# NOTE(review): Croatia (HRV, euro adoption 2023) is not listed although main()
# keeps the panel through 2024 — confirm the omission is intentional.
EUROZONE_JOIN = {
    'AUT': 1999, 'BEL': 1999, 'FIN': 1999, 'FRA': 1999, 'DEU': 1999,
    'IRL': 1999, 'ITA': 1999, 'LUX': 1999, 'NLD': 1999, 'PRT': 1999,
    'ESP': 1999, 'GRC': 2001, 'SVN': 2007, 'CYP': 2008, 'MLT': 2008,
    'SVK': 2009, 'EST': 2011, 'LVA': 2014, 'LTU': 2015,
}
# Every country that appears in EUROZONE_JOIN, regardless of join year.
EUROZONE_ISO3 = set(EUROZONE_JOIN.keys())

# Control set used by the Section 1 (full / OECD / non-OECD) regressions.
CONTROLS = ['fiscal_bal_gdp', 'nfa_gdp_lag', 'rgdp_growth', 'log_rel_opw', 'kaopen']
# Match trilemma paper's 3-control spec for eurozone regressions to avoid
# 34% sample attrition from log_rel_opw/kaopen missingness
CONTROLS_EZ = ['fiscal_bal_gdp', 'nfa_gdp_lag', 'rgdp_growth']
# Regressors of interest included in every specification; the attenuation
# reports track the coefficient on Z_1.
Z_VARS = ['Z_1', 'Z_2', 'Z_3']


# ── Helpers ──────────────────────────────────────────────────────────

def stars(p):
    """Return the significance marker for p-value ``p``.

    ``'***'`` below 0.01, ``'**'`` below 0.05, ``'*'`` below 0.1, and the
    empty string otherwise (including for NaN, which fails every comparison).
    """
    thresholds = ((0.01, '***'), (0.05, '**'), (0.1, '*'))
    for cutoff, marker in thresholds:
        if p < cutoff:
            return marker
    return ''


def fmt(val, se, p):
    """Format one coefficient for a regression table.

    Returns a pair of strings: the estimate to four decimals followed by its
    significance stars, and the standard error to four decimals wrapped in
    parentheses.
    """
    coef_text = f"{val:.4f}" + stars(p)
    se_text = "(" + f"{se:.4f}" + ")"
    return coef_text, se_text


def run_panel_gls(df, y_var, x_vars, label, feature_names=None):
    """Fit a PanelGLS model on the complete cases of ``df`` and summarise it.

    Args:
        df: Panel containing ``y_var``, every column in ``x_vars``, plus
            ``'iso3'`` and ``'year'``.
        y_var: Name of the dependent-variable column.
        x_vars: List of regressor column names.
        label: Model name used in console output and stored under ``'model'``.
        feature_names: Optional names to report instead of ``x_vars``
            (matched to coefficients by position).

    Returns:
        A dict with ``model``, ``n_obs``, ``n_countries``, ``r_squared``,
        ``rho`` and, per reported name, ``<name>_coef`` / ``<name>_se`` /
        ``<name>_p``; or ``None`` when fewer than 50 complete rows remain.
    """
    needed = [y_var, *x_vars, 'iso3', 'year']
    sample = df[needed].dropna()
    n_rows = len(sample)
    if n_rows < 50:
        print(f"  {label}: insufficient obs ({n_rows}), skipping")
        return None

    estimator = PanelGLS()
    estimator.fit(
        sample[y_var].values,
        sample[x_vars].values,
        sample['iso3'].values,
        sample['year'].values,
    )

    names = feature_names or x_vars
    summary = {
        'model': label,
        'n_obs': estimator.n_obs,
        'n_countries': estimator.n_countries,
        'r_squared': estimator.r_squared,
        'rho': estimator.rho,
    }
    # Key insertion order (coef, se, p per variable) is what later fixes the
    # row order of the markdown tables.
    for pos, name in enumerate(names):
        summary.update({
            f'{name}_coef': estimator.beta[pos],
            f'{name}_se': estimator.se[pos],
            f'{name}_p': estimator.pvalues[pos],
        })

    print(f"\n  {label} (N={estimator.n_obs}, R²={estimator.r_squared:.4f})")
    for pos, name in enumerate(names):
        marker = stars(estimator.pvalues[pos])
        print(f"    {name:30s} {estimator.beta[pos]:8.4f} ({estimator.se[pos]:.4f}) {marker}")

    return summary


def write_table(results, filename, title):
    """Write regression results as a markdown table under ``OUT_TABLES``.

    One column per model (in the order given), two rows per variable
    (estimate with stars, then standard error), followed by N, R² and the
    number of countries. Variables appear in first-seen order across models;
    a model that lacks a variable gets empty cells.

    Args:
        results: List of result dicts as returned by ``run_panel_gls``.
            Nothing is written when it is empty or ``None``.
        filename: File name created inside ``OUT_TABLES``.
        title: Heading placed at the top of the file.
    """
    if not results:
        return

    suffix = '_coef'
    all_vars = []
    for r in results:
        for key in r:
            if not key.endswith(suffix):
                continue
            # Strip only the trailing suffix: str.replace would also remove a
            # '_coef' embedded in the variable name and break the lookups below.
            vname = key[:-len(suffix)]
            if vname not in all_vars:
                all_vars.append(vname)

    sep = "|:---|" + "|".join("---:" for _ in results) + "|"
    lines = [
        f"# {title}\n",
        "| Variable | " + " | ".join(r['model'] for r in results) + " |",
        sep,
    ]

    for var in all_vars:
        coef_row = f"| {var} |"
        se_row = "| |"
        for r in results:
            if f'{var}_coef' in r:
                c, s = fmt(r[f'{var}_coef'], r[f'{var}_se'], r[f'{var}_p'])
                coef_row += f" {c} |"
                se_row += f" {s} |"
            else:
                coef_row += " |"
                se_row += " |"
        lines.append(coef_row)
        lines.append(se_row)

    # Divider between coefficients and fit statistics.
    # NOTE(review): a second delimiter row is rendered as literal text by
    # GFM-style markdown renderers — confirm this is the intended look.
    lines.append(sep)
    n_row = "| N |"
    r2_row = "| R² |"
    nc_row = "| Countries |"
    for r in results:
        n_row += f" {r['n_obs']} |"
        r2_row += f" {r['r_squared']:.4f} |"
        nc_row += f" {r['n_countries']} |"
    lines.extend([n_row, r2_row, nc_row])

    lines.append("\n*Panel GLS with country and year fixed effects. "
                 "Standard errors in parentheses.*")
    lines.append("*\\*p<0.1, \\*\\*p<0.05, \\*\\*\\*p<0.01*")

    path = OUT_TABLES / filename
    # The table contains non-ASCII characters (e.g. '→', '²'); an explicit
    # encoding keeps the write independent of the platform's locale default.
    path.write_text('\n'.join(lines), encoding='utf-8')
    print(f"\n  Saved: {path}")


def compute_attenuation(baseline, mediated, var='Z_1'):
    """Return the percentage change in ``var``'s coefficient between models.

    Computed as ``(baseline - mediated) / baseline * 100`` from the
    ``<var>_coef`` entries of the two result dicts. Returns ``None`` when
    either dict lacks the coefficient or the baseline is numerically zero
    (absolute value below 1e-10).
    """
    key = var + '_coef'
    if not (key in baseline and key in mediated):
        return None

    before, after = baseline[key], mediated[key]
    if abs(before) < 1e-10:
        return None
    return (before - after) / before * 100


def make_eurozone_sample(df):
    """Return the eurozone post-join sample.

    For every country in ``EUROZONE_JOIN`` keeps the rows from its join year
    onward, stacked country by country in the mapping's order with a fresh
    integer index.
    """
    pieces = [
        df[(df['iso3'] == country) & (df['year'] >= first_year)]
        for country, first_year in EUROZONE_JOIN.items()
    ]
    return pd.concat(pieces, ignore_index=True)


def make_oecd_floaters(df):
    """Return a copy of the rows for OECD countries outside the eurozone.

    The filter is by country only: any country listed in ``EUROZONE_ISO3`` is
    dropped for all years, including those before it joined.
    """
    floaters = []
    for code in OECD:
        if code not in EUROZONE_ISO3:
            floaters.append(code)
    keep = df['iso3'].isin(floaters)
    return df.loc[keep].copy()


# ── Main ─────────────────────────────────────────────────────────────

def _post_join_rows(df, iso3_list):
    """Return post-join rows of ``df`` for the listed eurozone members.

    Codes missing from ``EUROZONE_JOIN`` are ignored. Rows are stacked country
    by country in the order of ``iso3_list``. An empty selection yields an
    empty frame that keeps ``df``'s columns, so callers can still index it by
    column name.
    """
    parts = [
        df[(df['iso3'] == iso3) & (df['year'] >= EUROZONE_JOIN[iso3])]
        for iso3 in iso3_list
        if iso3 in EUROZONE_JOIN
    ]
    if not parts:
        return df.iloc[0:0].copy()
    return pd.concat(parts, ignore_index=True)


def _run_mediation_steps(sample, prefix, controls):
    """Run the three Baron & Kenny regressions on ``sample``.

    Step 1 regresses capital intensity on Z, step 2 regresses the current
    account on Z, step 3 adds capital intensity to step 2. Returns the three
    ``run_panel_gls`` results in that order (each may be ``None``).
    """
    base_x = Z_VARS + controls
    step1 = run_panel_gls(sample, 'capital_intensity', base_x,
                          f'{prefix} S1: Z→K/Y')
    step2 = run_panel_gls(sample, 'ca_gdp', base_x,
                          f'{prefix} S2: Z→CA')
    step3 = run_panel_gls(sample, 'ca_gdp',
                          Z_VARS + ['capital_intensity'] + controls,
                          f'{prefix} S3: Z+K/Y→CA')
    return step1, step2, step3


def _record_attenuation(store, key, shown_as, baseline, mediated):
    """Print the Z_1 attenuation for one sample and save it in ``store[key]``.

    Does nothing when either model is missing or the attenuation is
    undefined. ``shown_as`` is the console label, which for some samples
    differs from the summary-table key.
    """
    if not (baseline and mediated):
        return
    atten = compute_attenuation(baseline, mediated, 'Z_1')
    if atten is None:
        return
    b_base = baseline['Z_1_coef']
    b_med = mediated['Z_1_coef']
    print(f"  {shown_as}: Z₁ = {b_base:.4f} → {b_med:.4f} "
          f"(attenuation: {atten:.1f}%)")
    store[key] = {
        'baseline': b_base,
        'mediated': b_med,
        'baseline_p': baseline.get('Z_1_p', np.nan),
        'mediated_p': mediated.get('Z_1_p', np.nan),
        'attenuation': atten,
    }


def _interpret_attenuation(label, attenuation):
    """Return the interpretation text for one summary-table row."""
    # NOTE(review): these two labels are fixed narrative text and are used
    # whatever the estimated attenuation turns out to be.
    presets = {
        'Full': 'Amplification (non-OECD composition)',
        'OECD': 'Suppression (two opposing channels)',
    }
    preset = presets.get(label, '')
    if preset:
        return preset
    if attenuation > 50:
        return 'Mediation'
    if attenuation > 0:
        return 'Partial mediation'
    if attenuation < -50:
        return 'Amplification'
    return 'Weak/null'


def _write_suppression_summary(attenuation_results):
    """Write the cross-sample attenuation table to ``OUT_TABLES``."""
    lines = [
        "# Suppression Summary: Attenuation of Z₁ by Capital Intensity\n",
        "| Sample | Z₁ baseline | p | Z₁ + K/Y | p | Attenuation % | Interpretation |",
        "|:---|---:|---:|---:|---:|---:|:---|",
    ]

    row_order = ['Full', 'OECD', 'Non-OECD', 'Eurozone', 'OECD Float',
                 'EZ Core', 'EZ Periphery']
    for label in row_order:
        if label not in attenuation_results:
            continue
        r = attenuation_results[label]
        b_str = f"{r['baseline']:.2f}{stars(r['baseline_p'])}"
        m_str = f"{r['mediated']:.2f}{stars(r['mediated_p'])}"
        a_str = f"{r['attenuation']:.1f}%"
        interp = _interpret_attenuation(label, r['attenuation'])
        lines.append(f"| {label} | {b_str} | {r['baseline_p']:.4f} | "
                     f"{m_str} | {r['mediated_p']:.4f} | {a_str} | {interp} |")

    lines.append("\n*Attenuation = (baseline − mediated) / baseline × 100.*")
    lines.append("*Positive = mediation (K/Y absorbs Z effect). "
                 "Negative = amplification (K/Y acts as suppressor).*")
    lines.append("*Baron & Kenny (1986) framework. Panel GLS with country and year FE.*")

    path = OUT_TABLES / "phase7_suppression_summary.md"
    # The table contains 'Z₁', '−' and '×'; an explicit encoding keeps the
    # write independent of the platform's locale default.
    path.write_text('\n'.join(lines), encoding='utf-8')
    print(f"\n  Saved: {path}")


def main():
    """Run the Phase 7 analysis and write its markdown tables.

    Reads ``automation_panel.csv`` from ``DATA`` (years up to 2024), then:

    1. Section 1 - Z → CA with and without capital intensity on the full,
       OECD and non-OECD samples, plus an OECD old_dep × K/Y interaction.
    2. Section 2 - Baron & Kenny steps for eurozone post-join vs OECD
       floaters, eurozone interaction models, and a core/periphery split.
    3. Section 3 - a summary table of Z_1 attenuation across all samples.

    Progress and coefficients are printed to stdout; tables are written to
    ``OUT_TABLES``.
    """
    print("=" * 70)
    print("PHASE 7: OECD SUPPRESSION STORY & EUROZONE SPLIT")
    print("=" * 70)

    df = pd.read_csv(DATA / "automation_panel.csv")
    df = df[df['year'] <= 2024].copy()
    print(f"Panel: {len(df)} obs, {df['iso3'].nunique()} countries")

    is_oecd = df['iso3'].isin(OECD)
    df_oecd = df[is_oecd].copy()
    df_nonoecd = df[~is_oecd].copy()
    print(f"  OECD: {len(df_oecd)} obs, {df_oecd['iso3'].nunique()} countries")
    print(f"  Non-OECD: {len(df_nonoecd)} obs, {df_nonoecd['iso3'].nunique()} countries")

    # ══════════════════════════════════════════════════════════════════
    # SECTION 1: OECD SUPPRESSION DECOMPOSITION
    # ══════════════════════════════════════════════════════════════════
    print("\n" + "=" * 60)
    print("SECTION 1: OECD SUPPRESSION DECOMPOSITION")
    print("=" * 60)

    # (table/summary label, console description, sample)
    samples = [
        ('Full', 'Full Sample', df),
        ('OECD', 'OECD', df_oecd),
        ('Non-OECD', 'Non-OECD', df_nonoecd),
    ]
    baseline_x = Z_VARS + CONTROLS
    mediated_x = Z_VARS + ['capital_intensity'] + CONTROLS

    # M1-M3: Z → CA baselines, then M4-M6: the same samples with K/Y added.
    baselines = {}
    for num, (short, described, sample) in enumerate(samples, start=1):
        print(f"\n--- M{num}: Z → CA, {described} ---")
        baselines[short] = run_panel_gls(sample, 'ca_gdp', baseline_x,
                                         f'M{num}: {short}')
    mediateds = {}
    for num, (short, described, sample) in enumerate(samples, start=4):
        print(f"\n--- M{num}: Z + K/Y → CA, {described} ---")
        mediateds[short] = run_panel_gls(sample, 'ca_gdp', mediated_x,
                                         f'M{num}: {short}+K/Y')

    results_s1 = [r for r in (*baselines.values(), *mediateds.values()) if r]
    write_table(results_s1, "phase7_oecd_suppression.md",
                "OECD Suppression Decomposition: Z → CA with and without Capital Intensity")

    # Attenuation report
    print("\n--- Attenuation Summary ---")
    attenuation_results = {}
    for short, _described, _sample in samples:
        _record_attenuation(attenuation_results, short, short,
                            baselines[short], mediateds[short])

    # Additional: old_dep × capital_intensity interaction (OECD)
    print("\n--- OECD: old_dep × capital_intensity interaction ---")
    if 'old_dep' in df_oecd.columns:
        df_oecd_int = df_oecd.copy()
        df_oecd_int['old_dep_x_ki'] = df_oecd_int['old_dep'] * df_oecd_int['capital_intensity']
        int_vars = Z_VARS + ['capital_intensity', 'old_dep_x_ki'] + CONTROLS
        r_int = run_panel_gls(df_oecd_int, 'ca_gdp', int_vars,
                              'OECD: Z+K/Y+old_dep×K/Y')
        if r_int:
            p_int = r_int.get('old_dep_x_ki_p', np.nan)
            c_int = r_int.get('old_dep_x_ki_coef', np.nan)
            print(f"\n  old_dep × K/Y interaction: {c_int:.4f} (p={p_int:.4f})")

    # ══════════════════════════════════════════════════════════════════
    # SECTION 2: EUROZONE VS OECD FLOATERS
    # ══════════════════════════════════════════════════════════════════
    print("\n" + "=" * 60)
    print("SECTION 2: EUROZONE VS OECD FLOATERS")
    print("=" * 60)

    ez_df = make_eurozone_sample(df)
    float_df = make_oecd_floaters(df)
    print(f"  Eurozone (post-join): {len(ez_df)} obs, {ez_df['iso3'].nunique()} countries")
    print(f"  OECD floaters: {len(float_df)} obs, {float_df['iso3'].nunique()} countries")

    results_s2 = []

    print("\n--- Eurozone: Baron & Kenny Steps ---")
    ez_s1, ez_s2, ez_s3 = _run_mediation_steps(ez_df, 'EZ', CONTROLS_EZ)
    results_s2.extend(r for r in (ez_s1, ez_s2, ez_s3) if r)

    print("\n--- OECD Floaters: Baron & Kenny Steps ---")
    fl_s1, fl_s2, fl_s3 = _run_mediation_steps(float_df, 'Float', CONTROLS_EZ)
    results_s2.extend(r for r in (fl_s1, fl_s2, fl_s3) if r)

    write_table(results_s2, "phase7_eurozone_vs_floaters.md",
                "Eurozone vs OECD Floaters: Capital Intensity Mediation")

    # Eurozone and floater attenuation
    print("\n--- Eurozone vs Floater Attenuation ---")
    _record_attenuation(attenuation_results, 'Eurozone', 'Eurozone', ez_s2, ez_s3)
    _record_attenuation(attenuation_results, 'OECD Float', 'Floaters', fl_s2, fl_s3)

    # --- Interaction model: full OECD with eurozone dummies ---
    print("\n--- OECD Interaction Model: Z₁×eurozone, K/Y×eurozone ---")
    df_oecd_ez = df_oecd.copy()
    # Mark eurozone membership (post-join only)
    df_oecd_ez['eurozone'] = 0
    for iso3, join_yr in EUROZONE_JOIN.items():
        mask = (df_oecd_ez['iso3'] == iso3) & (df_oecd_ez['year'] >= join_yr)
        df_oecd_ez.loc[mask, 'eurozone'] = 1

    df_oecd_ez['Z1_x_ez'] = df_oecd_ez['Z_1'] * df_oecd_ez['eurozone']
    df_oecd_ez['ki_x_ez'] = df_oecd_ez['capital_intensity'] * df_oecd_ez['eurozone']

    results_s2_int = []
    int_vars_ez = Z_VARS + ['capital_intensity', 'eurozone', 'Z1_x_ez', 'ki_x_ez'] + CONTROLS_EZ
    r_int_ez = run_panel_gls(df_oecd_ez, 'ca_gdp', int_vars_ez,
                             'OECD: Z+K/Y+EZ interactions')
    if r_int_ez:
        results_s2_int.append(r_int_ez)
        # Also run without K/Y for comparison; it goes first in the table.
        int_vars_ez_noky = Z_VARS + ['eurozone', 'Z1_x_ez'] + CONTROLS_EZ
        r_int_ez_noky = run_panel_gls(df_oecd_ez, 'ca_gdp', int_vars_ez_noky,
                                      'OECD: Z+EZ interactions')
        if r_int_ez_noky:
            results_s2_int.insert(0, r_int_ez_noky)
    # These results used to be assembled and then dropped; persist them like
    # every other model set (write_table is a no-op on an empty list).
    write_table(results_s2_int, "phase7_eurozone_interactions.md",
                "OECD Eurozone Interaction Models: Z₁ × Eurozone and K/Y × Eurozone")

    # --- Core vs Periphery eurozone split ---
    print("\n" + "=" * 60)
    print("SECTION 2b: EUROZONE CORE VS PERIPHERY")
    print("=" * 60)

    EZ_CORE = ['DEU', 'NLD', 'AUT', 'FIN', 'BEL']
    EZ_PERIPHERY = ['ITA', 'ESP', 'PRT', 'GRC', 'IRL']

    core_df = _post_join_rows(df, EZ_CORE)
    peri_df = _post_join_rows(df, EZ_PERIPHERY)
    print(f"  Core (DEU,NLD,AUT,FIN,BEL): {len(core_df)} obs, "
          f"{core_df['iso3'].nunique()} countries")
    print(f"  Periphery (ITA,ESP,PRT,GRC,IRL): {len(peri_df)} obs, "
          f"{peri_df['iso3'].nunique()} countries")

    results_cp = []

    print("\n--- Core: Baron & Kenny Steps ---")
    core_s1, core_s2, core_s3 = _run_mediation_steps(core_df, 'Core', CONTROLS_EZ)
    results_cp.extend(r for r in (core_s1, core_s2, core_s3) if r)

    print("\n--- Periphery: Baron & Kenny Steps ---")
    peri_s1, peri_s2, peri_s3 = _run_mediation_steps(peri_df, 'Peri', CONTROLS_EZ)
    results_cp.extend(r for r in (peri_s1, peri_s2, peri_s3) if r)

    write_table(results_cp, "phase7_core_vs_periphery.md",
                "Eurozone Core vs Periphery: Capital Intensity Mediation")

    # Core/periphery attenuation
    print("\n--- Core vs Periphery Attenuation ---")
    _record_attenuation(attenuation_results, 'EZ Core', 'Core', core_s2, core_s3)
    _record_attenuation(attenuation_results, 'EZ Periphery', 'Periphery', peri_s2, peri_s3)

    # Compare first-stage Z → K/Y coefficients
    if core_s1 and peri_s1:
        print("\n  First-stage Z₁ → K/Y:")
        print(f"    Core:      {core_s1['Z_1_coef']:.4f} (p={core_s1['Z_1_p']:.4f})")
        print(f"    Periphery: {peri_s1['Z_1_coef']:.4f} (p={peri_s1['Z_1_p']:.4f})")

    # ══════════════════════════════════════════════════════════════════
    # SECTION 3: SUMMARY CROSS-TABULATION
    # ══════════════════════════════════════════════════════════════════
    print("\n" + "=" * 60)
    print("SECTION 3: SUPPRESSION SUMMARY TABLE")
    print("=" * 60)

    _write_suppression_summary(attenuation_results)

    print("\n" + "=" * 70)
    print("PHASE 7 COMPLETE")
    print("=" * 70)


# Run the analysis only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()
